import os
import pandas as pd
import matplotlib.pyplot as plt
import argparse

# Extension that marks a file as FASTQ input; only this trailing suffix is
# stripped when deriving the name written to the CSV.
FASTQ_SUFFIX = ".fastq"

# The script counts one read per four lines of a FASTQ file.
LINES_PER_READ = 4


def parse_args(argv=None):
    """Parse the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``directory``, ``plot_filename`` and ``csv_filename``.
    """
    parser = argparse.ArgumentParser(description='Count total reads in FASTQ files and plot read distribution.')
    parser.add_argument('directory', type=str, help='Path to the directory containing FASTQ files')
    parser.add_argument('plot_filename', type=str, help='Filename for the output plot')
    parser.add_argument('csv_filename', type=str, help='Filename for the output CSV file')
    return parser.parse_args(argv)


def count_reads(path):
    """Return the number of reads in the FASTQ file at *path*.

    The count is the number of lines divided by four, rounded down, so a
    trailing incomplete record is ignored.
    """
    # Binary mode: only newlines matter for counting, and skipping text
    # decoding means a stray non-decodable byte cannot abort the run.
    with open(path, "rb") as handle:
        num_lines = sum(1 for _ in handle)
    return num_lines // LINES_PER_READ


def collect_read_counts(directory):
    """Count reads for every ``.fastq`` file directly inside *directory*.

    Returns:
        List of ``(name, num_reads)`` tuples, where ``name`` is the filename
        without its trailing ``.fastq`` suffix, ordered by filename.
    """
    counts = []
    # os.listdir() returns entries in arbitrary order; sort so the CSV rows
    # are reproducible.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(FASTQ_SUFFIX):
            continue
        path = os.path.join(directory, filename)
        # A directory whose name ends in ".fastq" cannot be opened as a file.
        if not os.path.isfile(path):
            continue
        # Strip only the trailing suffix; str.replace() would also remove
        # ".fastq" from the middle of the name.
        name = filename[:-len(FASTQ_SUFFIX)]
        counts.append((name, count_reads(path)))
    return counts


def plot_read_distribution(read_counts, plot_filename):
    """Save a log-log plot of the read counts ranked from highest to lowest.

    Args:
        read_counts: Iterable of per-file read counts.
        plot_filename: Path the figure is written to.
    """
    read_rank = sorted(read_counts, reverse=True)
    x_axis = range(1, len(read_rank) + 1)

    fig, ax = plt.subplots()
    try:
        ax.loglog(x_axis, read_rank)
        ax.set_xlabel('Barcodes')
        ax.set_ylabel('mapped UMI')
        # Keep the upper limit above 1 so the limits are never identical when
        # there is nothing to plot.
        ax.set_xlim([1, max(len(read_rank), 1) + 1])
        fig.savefig(plot_filename)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def main(argv=None):
    """Count reads per FASTQ file, write the CSV and plot, print the total."""
    args = parse_args(argv)

    file_info = collect_read_counts(args.directory)

    # Save the per-file counts to a CSV file
    df = pd.DataFrame(file_info, columns=['Filename', 'Number of Reads'])
    df.to_csv(args.csv_filename, index=False)

    # Plot the ranked read distribution
    plot_read_distribution(list(df['Number of Reads']), args.plot_filename)

    # Print the total number of reads
    total_reads = sum(num_reads for _, num_reads in file_info)
    print(f"Total number of reads: {total_reads}")


if __name__ == "__main__":
    main()
